LstmGradWeight
计算 LSTM 网络的权重梯度,包括输入权重 W、隐藏权重 U 和偏置 b 的反向传播梯度。
\[\begin{split}dH_t &= dY_t + dH_{t+1} \\
dC_t &= dC_{t+1} \odot f_{t+1} + dH_t \odot o_t \odot (1 - \tanh^2(C_t)) \\
dX_t &= dA_t \cdot W^T \\
dW &= \sum_t dA_t \cdot X_t^T \\
dU &= \sum_t dA_t \cdot H_{t-1}^T \\
dA_t &= dH_t \odot o_t \odot (1 - \tanh^2(C_t)) \odot g'_t\end{split}\]
其中:
(dH_t) 表示隐藏状态的梯度。
(dC_t) 表示细胞状态的梯度。
(dX_t) 表示输入梯度。
(dW, dU) 分别表示输入权重和隐藏状态权重的梯度。
(dA_t) 表示门控单元梯度。
(f_t, o_t, C_t, g_t) 分别为遗忘门、输出门、细胞状态、输入门的前向值。
(odot) 表示元素逐乘。
- 输入:
params - 静态参数数组,包含 LSTM 网络配置、权重、状态指针等。
dynamic_params - 动态参数数组,用于存储运行时指针及中间梯度。
- 输出:
dX_ - 输入梯度。
dH_ - 隐藏状态梯度。
dC_ - 细胞状态梯度。
dA_tmp_ - 门控梯度中间缓存。
- 支持平台:
FT78NEMT7004
备注
FT78NE 支持 fp
MT7004 支持 hp, fp
共享存储版本:
-
void fp_lstmgradweight_s(long long *params, long long *dynamic_params, int core_mask)
-
void hp_lstmgradweight_s(long long *params, long long *dynamic_params, int core_mask)
C调用示例:
1#include <stdio.h>
2#include <lstmgradweight.h>
3
4int main() {
5 long long params[32];
6 long long dynamic_params[32];
7 int core_mask = 0xff;
8
9 // 初始化 params 和 dynamic_params
10 fp_lstmgradweight_s(params, dynamic_params, core_mask);
11 return 0;
12}
私有存储版本:
-
void fp_lstmgradweight_p(long long *params, long long *dynamic_params)
-
void hp_lstmgradweight_p(long long *params, long long *dynamic_params)
C调用示例:
1#include <stdio.h>
2#include <lstmgradweight.h>
3
4int main() {
5 long long params[32];
6 long long dynamic_params[32];
7
8 // 初始化 params 和 dynamic_params
9 fp_lstmgradweight_p(params, dynamic_params);
10 return 0;
11}